Recover binary search tree

Time: O(N); Space: O(1); hard

Two elements of a binary search tree (BST) are swapped by mistake.

Recover the tree without changing its structure.

Example 1:

  1
 /
3
 \
  2

Input: root = {TreeNode} [1,3,null,null,2]

  3
 /
1
 \
  2

Output: {TreeNode} [3,1,null,null,2]

Example 2:

  3
 / \
1   4
   /
  2

Input: root = {TreeNode} [3,1,4,null,null,2]

  2
 / \
1   4
   /
  3

Output: {TreeNode} [2,1,4,null,null,3]

Follow up:

A solution using O(n) space is straightforward. Could you devise a constant-space solution?

[1]:
class TreeNode(object):
    '''
    Definition for a binary tree node.

    Attributes:
        val: the value stored in the node.
        left: the left child (a TreeNode) or None.
        right: the right child (a TreeNode) or None.
    '''
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None

    def __repr__(self):
        '''
        Return the level-order serialization of the subtree rooted here.

        Missing children are written as "#" and trailing "#" markers are
        trimmed, e.g. the tree 1 -> (left) 3 -> (right) 2 is rendered as
        "[1, 3, '#', '#', 2]".
        '''
        # Local import keeps this notebook cell self-contained.
        from collections import deque

        serial = []
        # deque gives O(1) pops from the front; re-slicing a list for every
        # dequeued node made the traversal quadratic in the tree size.
        queue = deque([self])

        while queue:
            cur = queue.popleft()
            if cur is None:
                serial.append("#")
            else:
                serial.append(cur.val)
                queue.append(cur.left)
                queue.append(cur.right)

        # Drop the placeholders for absent children at the end of the walk.
        while serial and serial[-1] == "#":
            serial.pop()

        return repr(serial)

Auxiliary Tools

[2]:
from graphviz import Graph

class TreeTasks(object):
    '''
    Helpers for inspecting binary trees inside the notebook.
    '''

    def visualize_tree(self, tree):
        '''
        Draw ``tree`` with graphviz, display it, and return the graph.

        Left children are labelled with a leading "." and right children with
        a trailing "." so each child's side is visible in the drawing.

        :type tree: TreeNode or None
        :rtype: graphviz.Graph
        '''
        dot = Graph()

        def node_id(node):
            # Name graph nodes by object identity. str(node) falls back to
            # TreeNode.__repr__, which serializes the whole subtree, so two
            # distinct subtrees with equal contents would merge into a single
            # graph node; recomputing it per node was also quadratic.
            return str(id(node))

        def add_nodes_edges(node):
            # Recursively add node's children and the edges leading to them.
            if node.left:
                dot.node(name=node_id(node.left), label="." + str(node.left.val))
                dot.edge(node_id(node), node_id(node.left))
                add_nodes_edges(node.left)
            if node.right:
                dot.node(name=node_id(node.right), label=str(node.right.val) + ".")
                dot.edge(node_id(node), node_id(node.right))
                add_nodes_edges(node.right)

        # An empty tree renders as an empty graph instead of failing on .val.
        if tree is not None:
            dot.node(name=node_id(tree), label=str(tree.val))
            add_nodes_edges(tree)

        # NOTE(review): `display` is not imported in this cell; presumably it
        # is the IPython/Jupyter builtin available when the notebook runs.
        display(dot)

        return dot

Solution

[3]:
class Solution1(object):
    '''
    Recover a BST in which two node values were swapped, using a Morris
    in-order traversal (O(N) time, O(1) extra space).
    '''
    def recoverTree(self, root):
        '''
        :type root: TreeNode
        :rtype: TreeNode
        '''
        return self.MorrisTraversal(root)

    def MorrisTraversal(self, root):
        '''
        Walk the tree in order without a stack, record the out-of-order nodes,
        and swap their values back.

        The rightmost node of each left subtree is temporarily threaded back
        to its in-order successor; the thread is removed on the second visit,
        so the tree's structure is unchanged once the walk finishes.

        :type root: TreeNode
        :rtype: TreeNode (the same ``root``), or None for an empty tree
        '''
        if root is None:
            return None
        # broken[0]: first misplaced node; broken[1]: second misplaced node.
        broken = [None, None]
        pre, cur = None, root

        while cur:
            if cur.left is None:
                # No left subtree: visit cur, then move right (possibly along
                # a temporary thread back to an ancestor).
                self.detectBroken(broken, pre, cur)
                pre = cur
                cur = cur.right
            else:
                # Find cur's in-order predecessor: the rightmost node of the
                # left subtree, stopping early at an existing thread to cur.
                node = cur.left
                while node.right and node.right is not cur:
                    node = node.right

                if node.right is None:
                    # First arrival: thread the predecessor to cur, go left.
                    node.right = cur
                    cur = cur.left
                else:
                    # Second arrival: the left subtree is done. Visit cur,
                    # remove the thread, and continue to the right.
                    self.detectBroken(broken, pre, cur)
                    node.right = None
                    pre = cur
                    cur = cur.right

        # A tree with no inversion needs no repair; without this guard the
        # swap would dereference None.
        if broken[0] is not None:
            broken[0].val, broken[1].val = broken[1].val, broken[0].val

        return root

    def detectBroken(self, broken, pre, cur):
        '''
        Record an inversion between in-order neighbours ``pre`` and ``cur``.

        The first inversion fixes broken[0] to ``pre``; every inversion sets
        broken[1] to ``cur``. This covers both an adjacent swap (a single
        inversion) and a non-adjacent swap (two inversions).
        '''
        if pre is not None and pre.val > cur.val:
            if broken[0] is None:
                broken[0] = pre
            broken[1] = cur
[4]:
# Example 1: values 1 and 3 are swapped in the tree [1,3,null,null,2].
root = TreeNode(1)
root.left = TreeNode(3)
root.left.right = TreeNode(2)

s, t = Solution1(), TreeTasks()
tree = s.recoverTree(root)
dot = t.visualize_tree(tree)
../../_images/topics_tree_0099_recover_binary_search_tree_[O(N),O(1),hard]_6_0.svg
[5]:
# Example 2: values 3 and 2 are swapped in the tree [3,1,4,null,null,2].
root = TreeNode(3)
root.left, root.right = TreeNode(1), TreeNode(4)
root.right.left = TreeNode(2)

s, t = Solution1(), TreeTasks()
tree = s.recoverTree(root)
dot = t.visualize_tree(tree)
../../_images/topics_tree_0099_recover_binary_search_tree_[O(N),O(1),hard]_7_0.svg